library(tidyverse)
library(tidymodels)
library(textrecipes)23 Регрессионные модели с tidymodels
23.1 Регрессионные алгоритмы
В машинном обучении проблемы, связанные с количественным откликом, называют проблемами регрессии, а проблемы, связанные с качественным откликом, проблемами классификации. В прошлом уроке мы познакомились с простой и множественной регрессией, но регрессионных алгоритмов великое множество. Вот лишь некоторые из них:
полиномиальная регрессия: расширение линейной регрессии, позволяющее учитывать нелинейные зависимости.
логистическая регрессия: используется для прогнозирования категориальных (бинарных) откликов.
регрессия на опорных векторах (SVM): ищет гиперплоскость, позволяющую минимизировать ошибку в многомерном пространстве.
деревья регрессии: строят иерархическую древовидную модель, последовательно разбивая данные на подгруппы.
случайный лес: комбинирует предсказания множества деревьев для повышения точности и устойчивости.
Кроме того, существуют методы регуляризации линейных моделей, позволяющие существенно улучшить их качество на данных большой размерности (т.е. с большим количеством предкторов). К таким алгоритмам относятся гребневая регрессия и метод лассо. О них мы поговорим в одном из следующих уроков.
О математической стороне дела см. Г. Джеймс, Д. Уиттон, Т. Хасти, Р. Тибришани (2017). В этом уроке мы научимся работать с различными регрессионными алгоритмами, используя библиотеку tidymodels.
23.2 Библиотека tidymodels
Библиотека tidymodels позволяет обучать модели и оценивать их эффективность с использованием принципов опрятных данных. Она представляет собой набор пакетов R, которые разработаны для работы с машинным обучением и являются частью более широкой экосистемы tidyverse.
Вот некоторые из ключевых пакетов, входящих в состав tidymodels:
parsnip- универсальный интерфейс для различных моделей машинного обучения, который упрощает переключение между разными типами моделей;recipes- фреймворк для создания и управления “рецептами” предварительной обработки данных перед тренировкой модели;rsample- инструменты для разделения данных на обучающую и тестовую выборки, а также для кросс-валидации;tune- функции для оптимизации гиперпараметров моделей машинного обучения;yardstick- инструменты для оценки производительности моделей;workflowпозволяет объединить различные компоненты модели в единый объект: препроцессинг данных, модель машинного обучения, настройку гиперпараметров.
Мы также будем использовать пакет textrecipes, который представляет собой аналог recipes для текстовых данных.
23.3 Данные
Датасет для этого урока хранит данные о названиях, рейтингах, жанре, цене и числе отзывов на некоторые книги с Amazon. Мы попробуем построить регресионную модель, которая будет предсказывать цену книги.
books <- readxl::read_xlsx("../files/AmazonBooks.xlsx")
booksДанные не очень опрятны, и прежде всего их надо тайдифицировать.
colnames(books) <- tolower(colnames(books))
books <- books |>
rename(rating = `user rating`)На графике ниже видно, что сильной корреляции между количественными переменными не прослеживается, так что задача перед нами стоит незаурядная. Посмотрим, что можно сделать в такой ситуации.
books |>
select_if(is.numeric) |>
cor() |>
corrplot::corrplot(method = "ellipse")
Мы видим, что количественные предикторы объясняют лишь ничтожную долю дисперсии (чуть более информативен жанр).
summary(lm(price ~ reviews + year + rating + genre, data = books))
Call:
lm(formula = price ~ reviews + year + rating + genre, data = books)
Residuals:
Min 1Q Median 3Q Max
-16.472 -5.050 -1.841 2.307 89.686
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 8.987e+02 2.734e+02 3.287 0.00107 **
reviews 7.779e-07 3.181e-05 0.024 0.98050
year -4.324e-01 1.370e-01 -3.156 0.00168 **
rating -3.655e+00 1.933e+00 -1.891 0.05909 .
genreNon Fiction 3.920e+00 8.669e-01 4.522 7.41e-06 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 10.16 on 595 degrees of freedom
Multiple R-squared: 0.06903, Adjusted R-squared: 0.06277
F-statistic: 11.03 on 4 and 595 DF, p-value: 1.235e-08
Посмотрим, можно ли как-то улучшить этот результат. Но сначала оценим визуально связь между ценой, с одной стороны, и годом и жанром, с другой.
g1 <- books |>
ggplot(aes(year, price, color = genre, group = genre)) +
geom_jitter(show.legend = FALSE, alpha = 0.7) +
geom_smooth(method = "lm", se = FALSE) +
theme_minimal()
g2 <- books |>
ggplot(aes(genre, price, color = genre)) +
geom_boxplot() +
theme_minimal()
gridExtra::grid.arrange(g1, g2, nrow = 1)
23.4 Обучающая и контрольная выборка
Вы уже знаете, при обучении модели мы стремимся к минимизации среднеквадратичной ошибки (MSE), однако в большинстве случаев нас интересует не то, как метод работает на обучающих данных, а то, как он покажет себя на контрольных данных. Чтобы избежать переобучения, очень важно в самом начале разделить доступные наблюдения на две группы.
books_split <- books |>
initial_split()
books_train <- training(books_split)
books_test <- testing(books_split)23.5 Определение модели
Определение модели включает следующие шаги:
указывается тип модели на основе ее математической структуры (например, линейная регрессия, случайный лес, KNN и т. д.);
указывается механизм для подгонки модели – чаще всего это программный пакет, который должен быть использован, например
glmnet. Это самостоятельные модели, иparsnipобеспечивает согласованные интерфейсы, используя их в качестве движков для моделирования.при необходимости объявляется режим модели. Режим отражает тип прогнозируемого результата. Для числовых результатов режимом является регрессия, для качественных - классификация. Если алгоритм модели может работать только с одним типом результатов прогнозирования, например, линейной регрессией, режим уже задан.
23.6 Регрессия на опорных векторах
Начнем с регрессии на опорных векторах. Функция translate() позволяет понять, как parsnip переводит пользовательский код на язык пакета.
svm_spec <- svm_linear() |>
set_engine("LiblineaR") |>
set_mode("regression")
svm_spec |>
translate()Linear Support Vector Machine Model Specification (regression)
Computational engine: LiblineaR
Model fit template:
LiblineaR::LiblineaR(x = missing_arg(), y = missing_arg(), type = 11,
svr_eps = 0.1)
Пока это просто спецификация модели без данных и без формулы. Добавим ее к воркфлоу.
svm_wflow <- workflow() |>
add_model(svm_spec)
svm_wflow══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: None
Model: svm_linear()
── Model ───────────────────────────────────────────────────────────────────────
Linear Support Vector Machine Model Specification (regression)
Computational engine: LiblineaR
23.7 Дизайн переменных
Теперь нам нужен препроцессор. За него отвечает пакет recipes. Если вы не уверены, какие шаги необходимы на этом этапе, можно заглянуть в шпаргалку. В случае с линейной регрессией это может быть логарифмическая трансформация, нормализация, отсев переменных с нулевой дисперсией (zero variance), добавление (impute) недостающих значений или удаление переменных, которые коррелируют с другими переменными.
Вот так выглядит наш первый рецепт. Обратите внимание, что формула записывается так же, как мы это делали ранее внутри функции lm().
books_rec <- recipe(price ~ year + genre + name,
data = books_train) |>
step_dummy(genre) |>
step_normalize(year) |>
step_tokenize(name) |>
step_tokenfilter(name, max_tokens = 1000) |>
step_tfidf(name) При желании можно посмотреть на результат предобработки.
prep(books_rec, books_train) |>
bake(new_data = NULL)Добавляем препроцессор в воркфлоу.
svm_wflow <- svm_wflow |>
add_recipe(books_rec)
svm_wflow══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_linear()
── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps
• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()
── Model ───────────────────────────────────────────────────────────────────────
Linear Support Vector Machine Model Specification (regression)
Computational engine: LiblineaR
23.8 Подгонка модели
Подгоним модель на обучающих данных.
svm_fit <- svm_wflow |>
fit(data = books_train)Пакет broom позволяет тайдифицировать модель. Посмотрим на слова, которые приводят к “удорожанию” книг. Видно, что в начале списка – слова, связанные с научными публикациями, что не лишено смысла.
svm_fit |>
tidy() |>
arrange(-estimate)Оценим модель на контрольных данных.
pred_data <- tibble(truth = books_test$price,
estimate = predict(svm_fit, books_test)$.pred)
books_metrics <- metric_set(rmse, rsq, mae)
books_metrics(pred_data, truth = truth, estimate = estimate)23.9 Повторные выборки
Чтобы не распечатывать каждый раз тестовые данные (в идеале мы их используем один, максимум два раза!), задействуется ряд методов, позволяющих оценить ошибку путем исключения части обучающих наблюдений из процесса подгонки модели и последующего применения этой модели к исключенным наблюдениям.
В пакете rsample из библиотеки tidymodels реализованы, среди прочего, следующие методы повторных выборок для оценки производительности моделей машинного обучения:
Метод проверочной выборки – набор наблюдений делится на обучающую и проверочную, или удержанную, выборку (validation set): для этого используется
initial_validation_split().K-кратная перекрестная проверка – наблюдения разбиваются на k групп примерно одинакового размера, первый блок служит в качестве проверочной выборки, а модель подгоняется по остальным k-1 блокам; процедура повторяется k раз: функция
vfold_cv().Перекрестная проверка Монте-Карло – в отличие от предыдущего метода, создается множество случайных разбиений данных на обучающую и тестовую выборки: функция
mc_cv().Бутстреп – отбор наблюдений выполняется с возвращением, т.е. одно и то же наблюдение может встречаться несколько раз: функция
bootstraps().Перекрестная проверка по отдельным наблюдениям (leave-one-out сross-validation): одно наблюдение используется в качестве контрольного, а остальные составляют обучающую выборку; модель подгоняется по n-1 наблюдениям, что повторяется n раз: функция
loo_cv().
Эти методы повторных выборок позволяют получить надежные оценки производительности моделей машинного обучения, избегая переобучения и обеспечивая репрезентативность тестовых выборок.
set.seed(05102024)
books_folds <- vfold_cv(books_train, v = 10)
set.seed(05102024)
svm_rs <- fit_resamples(
svm_wflow,
books_folds,
control = control_resamples(save_pred = TRUE)
)Теперь соберем метрики и убедимся, что предыдущая оценка на контрольных данных была слишком оптимистичной. Однако результат не так уж плох: во всяком случае мы смогли добиться заметного улучшения по сравнению с нулевой моделью.
collect_metrics(svm_rs)svm_rs |>
collect_predictions() |>
ggplot(aes(price, .pred, color = id)) +
geom_jitter(alpha = 0.3) +
geom_abline(lty = 2, color = "grey80") +
theme_minimal() +
coord_cartesian(xlim = c(0,50), ylim = c(0,50))
23.10 Нулевая модель
Кстати, проверим, какой результат даст нулевая модель.
null_reg <- null_model() |>
set_engine("parsnip") |>
set_mode("regression")
null_wflow <- workflow() |>
add_model(null_reg) |>
add_recipe(books_rec)
null_rs <- fit_resamples(
null_wflow,
books_folds,
control = control_resamples(save_pred = TRUE)
)→ A | warning: A correlation computation is required, but `estimate` is constant and has 0
standard deviation, resulting in a divide by 0 error. `NA` will be returned.
There were issues with some computations A: x1
There were issues with some computations A: x10
collect_metrics(null_rs)\(R^2\) в таком случае должен быть NaN.
23.11 Случайный лес
Уточним, какие движки доступны для случайных лесов.
show_engines("rand_forest")Создадим спецификацию модели. Деревья используются как в задачах классификации, так и в задачах регрессии, поэтому задействуем функцию set_mode().
rf_spec <- rand_forest(trees = 1000) |>
set_engine("ranger") |>
set_mode("regression")rf_wflow <- workflow() |>
add_model(rf_spec) |>
add_recipe(books_rec)
rf_wflow══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps
• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()
── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (regression)
Main Arguments:
trees = 1000
Computational engine: ranger
Обучение займет чуть больше времени.
rf_rs <- fit_resamples(
rf_wflow,
books_folds,
control = control_resamples(save_pred = TRUE)
)Мы видим, что среднеквадратическая ошибка уменьшилась, а доля объясненной дисперсии выросла.
collect_metrics(rf_rs)Тем не менее на графике можно заметить нечто странное: наша модель систематически переоценивает низкие значения и недооценивает высокие. Это связано с тем, что случайные леса не очень подходят для работы с разреженными данными (Hvitfeldt и Silge 2022).
rf_rs |>
collect_predictions() |>
ggplot(aes(price, .pred, color = id)) +
geom_jitter(alpha = 0.3) +
geom_abline(lty = 2, color = "grey80") +
theme_minimal() +
coord_cartesian(xlim = c(0, 50), ylim = c(0, 50))
23.12 Градиентные бустинговые деревья
Также попробуем построить регрессию с использованием градиентных бустинговых деревьев. Это один из алгоритмов ансамблевого машинного обучения, который строит последовательность простых моделей решающих деревьев, каждая из которых работает над ошибками предыдущей. В 2023 г. эта техника показала хорошие результаты в эксперименте по датировке греческих документальных папирусов.
xgb_spec <-
boost_tree(mtry = 50, trees = 1000) |>
set_engine("xgboost") %>%
set_mode("regression")xgb_wflow <- workflow() |>
add_model(xgb_spec) |>
add_recipe(books_rec)
xgb_wflow══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: boost_tree()
── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps
• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()
── Model ───────────────────────────────────────────────────────────────────────
Boosted Tree Model Specification (regression)
Main Arguments:
mtry = 50
trees = 1000
Computational engine: xgboost
Проводим перекрестную проверку.
xgb_rs <- fit_resamples(
xgb_wflow,
books_folds,
control = control_resamples(save_pred = TRUE)
)collect_metrics(xgb_rs)Метрики неплохие! Но если взглянуть на остатки, можно увидеть что-то вроде буквы S.
rf_rs |>
collect_predictions() |>
ggplot(aes(price, .pred, color = id)) +
geom_jitter(alpha = 0.3) +
geom_abline(lty = 2, color = "grey80") +
theme_minimal() +
coord_cartesian(xlim = c(0, 50), ylim = c(0, 50))
23.13 Удаление стопслов
Изменим рецепт приготовления данных.
stopwords_rec <- function(stopwords_name) {
recipe(price ~ year + genre + name, data = books_train) |>
step_dummy(genre) |>
step_normalize(year) |>
step_tokenize(name) |>
step_stopwords(name, stopword_source = stopwords_name) |>
step_tokenfilter(name, max_tokens = 1000) |>
step_tfidf(name)
}Создадим воркфлоу.
svm_wflow <- workflow() |>
add_model(svm_spec)И снова проведем перекрестную проверку, на этот раз с разными списками стоп-слов. На этом шаге команда вернет предупреждения о том, что число слов меньше 1000, это нормально, т.к. после удаления стопслов токенов стало меньше.
set.seed(123)
snowball_rs <- fit_resamples(
svm_wflow |> add_recipe(stopwords_rec("snowball")),
books_folds
)
set.seed(234)
smart_rs <- fit_resamples(
svm_wflow |> add_recipe(stopwords_rec("smart")),
books_folds
)
set.seed(345)
stopwords_iso_rs <- fit_resamples(
svm_wflow |> add_recipe(stopwords_rec("stopwords-iso")),
books_folds
)collect_metrics(smart_rs)collect_metrics(snowball_rs)collect_metrics((stopwords_iso_rs))В нашем случае удаление стоп-слов положительного эффекта не имело.
word_counts <- tibble(name = c("snowball", "smart", "stopwords-iso")) %>%
mutate(words = map_int(name, ~length(stopwords::stopwords(source = .))))
list(snowball = snowball_rs,
smart = smart_rs,
`stopwords-iso` = stopwords_iso_rs) |>
map_dfr(show_best, metric = "rmse", .id = "name") |>
left_join(word_counts, by = "name") |>
mutate(name = paste0(name, " (", words, " words)"),
name = fct_reorder(name, words)) |>
ggplot(aes(name, mean, color = name)) +
geom_crossbar(aes(ymin = mean - std_err, ymax = mean + std_err), alpha = 0.6) +
geom_point(size = 3, alpha = 0.8) +
theme(legend.position = "none") +
theme_minimal()
23.14 Настройки числа n-grams
ngram_rec <- function(ngram_options) {
recipe(price ~ year + genre + name, data = books_train) |>
step_dummy(genre) |>
step_normalize(year) |>
step_tokenize(name, token = "ngrams", options = ngram_options) |>
step_tokenfilter(name, max_tokens = 1000) |>
step_tfidf(name)
}fit_ngram <- function(ngram_options) {
fit_resamples(
svm_wflow %>% add_recipe(ngram_rec(ngram_options)),
books_folds
)
}set.seed(123)
unigram_rs <- fit_ngram(list(n = 1))
set.seed(234)
bigram_rs <- fit_ngram(list(n = 2, n_min = 1))
set.seed(345)
trigram_rs <- fit_ngram(list(n = 3, n_min = 1))collect_metrics(unigram_rs)collect_metrics(bigram_rs)collect_metrics(trigram_rs)Таким образом, униграмы дают лучший результат:
list(`1` = unigram_rs,
`1 and 2` = bigram_rs,
`1, 2, and 3` = trigram_rs) |>
map_dfr(collect_metrics, .id = "name") |>
filter(.metric == "rmse") |>
ggplot(aes(name, mean, color = name)) +
geom_crossbar(aes(ymin = mean - std_err, ymax = mean + std_err),
alpha = 0.6) +
geom_point(size = 3, alpha = 0.8) +
theme(legend.position = "none") +
labs(
y = "RMSE"
) +
theme_minimal()
23.15 Лучшая модель и оценка
svm_fit <- svm_wflow |>
add_recipe(books_rec) |>
fit(data = books_test)Warning: max_tokens was set to 1000, but only 570 was available and selected.
svm_fit══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_linear()
── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps
• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()
── Model ───────────────────────────────────────────────────────────────────────
$TypeDetail
[1] "L2-regularized L2-loss support vector regression primal (L2R_L2LOSS_SVR)"
$Type
[1] 11
$W
year genre_Non.Fiction tfidf_name_1 tfidf_name_10 tfidf_name_11
[1,] -1.603091 3.672535 1.950128 -1.43626 -1.184989
tfidf_name_12 tfidf_name_15 tfidf_name_17 tfidf_name_1936 tfidf_name_2
[1,] 0.7845823 1.055154 2.023146 -0.4562354 -0.8983199
tfidf_name_2.0 tfidf_name_20 tfidf_name_2016 tfidf_name_3 tfidf_name_30
[1,] 0.6441868 -0.3845356 2.316836 -0.7645377 0.5338076
tfidf_name_4 tfidf_name_451 tfidf_name_5 tfidf_name_52 tfidf_name_6
[1,] -1.58615 -0.2057711 0.6399975 -0.03103396 -2.447489
tfidf_name_6th tfidf_name_7 tfidf_name_700 tfidf_name_8 tfidf_name_a
[1,] 8.399278 0.9353007 -0.03854977 -2.511344 -5.69612
tfidf_name_about tfidf_name_acid tfidf_name_activity tfidf_name_adult
[1,] -0.2692211 1.657285 -0.7645377 -0.5636503
tfidf_name_adults tfidf_name_adventures tfidf_name_adversity
[1,] -0.5636503 0.3457586 -0.5273754
tfidf_name_after tfidf_name_again tfidf_name_agreements
[1,] -0.5680237 0.8277979 -1.578453
tfidf_name_alchemist tfidf_name_almost tfidf_name_america
[1,] 10.12624 -1.068412 -1.203286
tfidf_name_america's tfidf_name_american tfidf_name_americans
[1,] 0.1490499 6.308633 -0.09535955
tfidf_name_an tfidf_name_ancient tfidf_name_and tfidf_name_animals
[1,] 0.4980579 0.1490499 -2.040945 -1.41944
tfidf_name_antidote tfidf_name_antiracist tfidf_name_are tfidf_name_art
[1,] 0.7845823 0.6596925 -0.3225596 1.109151
tfidf_name_assassination tfidf_name_association tfidf_name_astrophysics
[1,] -0.5985014 8.399278 -0.827998
tfidf_name_at tfidf_name_awesome tfidf_name_b tfidf_name_back
[1,] -0.4562354 -0.5749855 -0.5273754 -2.991323
tfidf_name_badass tfidf_name_ball tfidf_name_ballad tfidf_name_barefoot
[1,] -0.5749855 0.2516445 1.820692 0.9473725
tfidf_name_be tfidf_name_bear tfidf_name_beasts tfidf_name_become
[1,] 0.680159 -1.429475 0.926818 0.2012908
tfidf_name_bed tfidf_name_beginner's tfidf_name_believing tfidf_name_belly
[1,] -1.033543 -0.8300828 0.2012908 -1.478175
tfidf_name_berlin tfidf_name_bill tfidf_name_blood tfidf_name_boat
[1,] 0.4061898 -1.203286 -1.245664 -0.4562354
tfidf_name_body tfidf_name_book tfidf_name_books tfidf_name_boxed
[1,] -0.7645377 -4.853441 -0.5636503 6.77096
tfidf_name_boys tfidf_name_brain tfidf_name_brain's tfidf_name_brawl
[1,] -0.4562354 -0.5131912 -0.5131912 -0.9788472
tfidf_name_brothers tfidf_name_brown tfidf_name_brush tfidf_name_building
[1,] 0.4325667 -1.429475 -0.8300828 -0.5273754
tfidf_name_burn tfidf_name_by tfidf_name_called tfidf_name_calligraphy
...
and 266 more lines.
Взглянем на остатки. Для этого пригодится уже знакомая функция augment() из пакета broom.
svm_res <- augment(svm_fit, new_data = books_test) |>
mutate(res = price - .pred) |>
select(price, .pred, res)
svm_reslibrary(gridExtra)
g1 <- svm_res |>
mutate(res = price - .pred) |>
ggplot(aes(res)) +
geom_histogram(fill = "steelblue", color = "white") +
theme_minimal()
g2 <- svm_res |>
ggplot(aes(price, .pred)) +
geom_jitter(color = "steelblue", alpha = 0.7) +
geom_abline(linetype = 2, color = "grey80", linewidth = 2) +
theme_minimal()
grid.arrange(g1, g2, nrow = 1)
Соберем метрики.
books_metrics <- metric_set(rmse, rsq, mae)
books_metrics(svm_res, truth = price, estimate = .pred)Также посмотрим, какие слова больше всего связаны с увеличением и с уменьшением цены.
svm_fit |>
tidy() |>
filter(term != "year") |>
filter(!str_detect(term, "genre")) |>
mutate(sign = case_when(estimate > 0 ~ "дороже",
.default = "дешевле"),
estimate = abs(estimate),
term = str_remove_all(term, "tfidf_name_")) |>
group_by(sign) |>
top_n(20, estimate) |>
ungroup() |>
ggplot(aes(x = estimate, y = fct_reorder(term, estimate),
fill = sign)) +
geom_col(show.legend = FALSE) +
scale_x_continuous(expand = c(0,0)) +
facet_wrap(~sign, scales = "free") +
labs(y = NULL,
title = "Связь слов с ценой книг") +
theme_minimal()
Любопытно: судя по нашему датасету, конституция США раздается на Амазоне бесплатно.